[AArch64][SVE2] Use rshrnb for masked stores #70026
Conversation
@llvm/pr-subscribers-backend-aarch64

Author: Matthew Devereau (MDevereau)

Changes

This patch is a follow-up to https://reviews.llvm.org/D155299. It combines add+lsr into rshrnb when 'B' in:

C = A + B
D = C >> Shift

is equal to (1 << (Shift - 1)), and the bits in the top half of each vector element are zeroed or ignored, such as in a truncating masked store.

Full diff: https://github.com/llvm/llvm-project/pull/70026.diff

2 Files Affected:
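The identity the combine relies on can be sketched as a scalar model (a hedged Python illustration of the per-lane arithmetic, not LLVM code): adding 1 << (Shift - 1) before a logical shift right is exactly a round-half-up shift, which is what the rshrnb instruction performs while narrowing.

```python
def round_shift_right(a: int, shift: int) -> int:
    # Round-half-up shift right: the add+lsr pattern the combine
    # matches, i.e. (a + (1 << (shift - 1))) >> shift.
    return (a + (1 << (shift - 1))) >> shift

# 96 / 64 = 1.5 rounds up to 2; a plain lsr would truncate to 1.
assert round_shift_right(96, 6) == 2
assert (96 >> 6) == 1
```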
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a16a102e472e709..09ab4ddacddf138 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21017,6 +21017,19 @@ static SDValue performMSTORECombine(SDNode *N,
}
}
+  if (MST->isTruncatingStore()) {
+    if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
+      EVT ValueVT = Value->getValueType(0);
+      EVT MemVT = MST->getMemoryVT();
+      if ((ValueVT == MVT::nxv8i16 && MemVT == MVT::nxv8i8) ||
+          (ValueVT == MVT::nxv4i32 && MemVT == MVT::nxv4i16) ||
+          (ValueVT == MVT::nxv2i64 && MemVT == MVT::nxv2i32)) {
+        return DAG.getMaskedStore(MST->getChain(), DL, Rshrnb,
+                                  MST->getBasePtr(), MST->getOffset(),
+                                  MST->getMask(), MST->getMemoryVT(),
+                                  MST->getMemOperand(),
+                                  MST->getAddressingMode(), true);
+      }
+    }
+  }
+
return SDValue();
}
diff --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
index a913177623df9ec..0afd11d098a0009 100644
--- a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -298,3 +298,22 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
store <vscale x 2 x i16> %3, ptr %4, align 1
ret void
}
+
+define void @masked_store_rshrnb(ptr %ptr, ptr %dst, i64 %index, <vscale x 8 x i1> %mask) { ; preds = %vector.body, %vector.ph
+; CHECK-LABEL: masked_store_rshrnb:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT: rshrnb z0.b, z0.h, #6
+; CHECK-NEXT: st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT: ret
+ %wide.masked.load = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %ptr, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> poison)
+ %1 = add <vscale x 8 x i16> %wide.masked.load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+ %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+ %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+ %4 = getelementptr inbounds i8, ptr %dst, i64 %index
+ tail call void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8> %3, ptr %4, i32 1, <vscale x 8 x i1> %mask)
+ ret void
+}
+
+declare void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8>, ptr, i32, <vscale x 8 x i1>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
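A scalar reference model of what the new test expects per lane (a sketch for illustration only, not part of the patch): a rounding shift right by 6 on a 16-bit lane, keeping only the low 8 bits that the truncating st1b actually stores, which is why the upper bits produced by the shift can be ignored.

```python
def rshrnb_lane_u8(x: int, shift: int) -> int:
    # Per-lane model of the combined pattern: round-half-up shift
    # of a 16-bit value, then truncate to the 8 bits st1b stores.
    return ((x + (1 << (shift - 1))) >> shift) & 0xFF

# 255 rounds and shifts to 4, well inside 8 bits.
assert rshrnb_lane_u8(0x00FF, 6) == 4
# 0xFFFF shifts to 1024 (0x400); the truncating store keeps only 0x00,
# illustrating that the top-half bits are ignored.
assert rshrnb_lane_u8(0xFFFF, 6) == 0
```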
✅ With the latest revision this PR passed the C/C++ code formatter.
Thanks for this @MDevereau - looks like another nice improvement! I just have one comment about potentially refactoring the code ...
if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
  EVT ValueVT = Value->getValueType(0);
  EVT MemVT = MST->getMemoryVT();
  if ((ValueVT == MVT::nxv8i16 && MemVT == MVT::nxv8i8) ||
I think you have similar checks in performSTORECombine. It would be nice to combine them somehow into a helper function and reuse the logic, perhaps something like
bool isHalvingTruncateOfLegalScalableType(MVT SrcVT, MVT DstVT) {
return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv8i8) ||
(SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv4i16) ||
(SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv2i32);
}
Also, I think I'd prefer we do this check before calling trySimplifySrlAddToRshrnb to prevent us doing unnecessary extra work and potentially creating nodes that we throw away. I realise we also do this in performSTORECombine, but it would be great to re-order the checks in that function too!
Done!
LGTM! Thanks for the changes @MDevereau. :)